{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook [3]: Training the reader on the SQuAD v1.1 dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to fine-tune a pre-trained BERT model on the SQuAD v1.1 dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***Note:*** *To run this notebook you will need access to a GPU. The Reader was fine-tuned on an AWS EC2 p3.2xlarge machine (Tesla V100 16GB GPU), which took about 2 hours to complete (2 epochs on the SQuAD 1.1 train set were enough to achieve SOTA results on the SQuAD 1.1 dev set).*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-20T13:45:44.624084Z",
     "start_time": "2019-07-20T13:45:41.394789Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/supercalculateur/source/andre/cdqa-dev/env-cdqa/lib/python3.6/site-packages/tqdm/autonotebook/__init__.py:18: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
      "  \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n",
      "I1120 11:43:47.615704 140657575868224 file_utils.py:39] PyTorch version 1.2.0 available.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "import joblib\n",
    "from cdqa.reader import BertProcessor, BertQA\n",
    "from cdqa.utils.download import download_squad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download SQuAD datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-20T13:46:13.505754Z",
     "start_time": "2019-07-20T13:46:00.589821Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading SQuAD v1.1 data...\n",
      "train-v1.1.json already downloaded\n",
      "dev-v1.1.json already downloaded\n",
      "\n",
      "Downloading SQuAD v2.0 data...\n",
      "train-v2.0.json already downloaded\n",
      "dev-v2.0.json already downloaded\n"
     ]
    }
   ],
   "source": [
    "# Fetch the SQuAD JSON files under ./data; the preprocessing step reads them from there.\n",
    "download_squad(dir='./data')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocess SQuAD 1.1 examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-20T13:58:36.512980Z",
     "start_time": "2019-07-20T13:46:44.792080Z"
    },
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I1120 11:43:48.194295 140657575868224 tokenization_utils.py:375] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/supercalculateur/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n"
     ]
    }
   ],
   "source": [
    "# Processor configured for training mode with lower-cased input text.\n",
    "train_processor = BertProcessor(\n",
    "    do_lower_case=True,\n",
    "    is_training=True,\n",
    ")\n",
    "\n",
    "# fit_transform returns a pair, unpacked here as (examples, features), built from the SQuAD 1.1 train file.\n",
    "train_examples, train_features = train_processor.fit_transform(\n",
    "    X='./data/SQuAD_1.1/train-v1.1.json'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I1120 11:43:53.164162 140657575868224 configuration_utils.py:152] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/supercalculateur/.cache/torch/transformers/distributed_-1/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c\n",
      "I1120 11:43:53.165523 140657575868224 configuration_utils.py:169] Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"finetuning_task\": null,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"is_decoder\": false,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"num_labels\": 2,\n",
      "  \"output_attentions\": false,\n",
      "  \"output_hidden_states\": false,\n",
      "  \"output_past\": true,\n",
      "  \"pruned_heads\": {},\n",
      "  \"torchscript\": false,\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_bfloat16\": false,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "I1120 11:43:53.591548 140657575868224 modeling_utils.py:383] loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin from cache at /home/supercalculateur/.cache/torch/transformers/distributed_-1/aa1ef1aede4482d0dbcd4d52baad8ae300e60902e88fcb0bebdec09afd232066.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n",
      "I1120 11:43:55.430284 140657575868224 modeling_utils.py:453] Weights of BertForQuestionAnswering not initialized from pretrained model: ['qa_outputs.weight', 'qa_outputs.bias']\n",
      "I1120 11:43:55.431005 140657575868224 modeling_utils.py:456] Weights from pretrained model not used in BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9fb44d7dc6854474a6dc36ea50168573",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=2, style=ProgressStyle(description_width='initial…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f50802fcf50043f1bfa008c9b911d3df",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Iteration', max=4, style=ProgressStyle(description_width='ini…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c8eaa69941804829bc7c2c984487f7d2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Iteration', max=4, style=ProgressStyle(description_width='ini…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "BertQA(adam_epsilon=1e-08, bert_model='bert-base-uncased', do_lower_case=True,\n",
       "       fp16=False, gradient_accumulation_steps=1, learning_rate=3e-05,\n",
       "       local_rank=-1, loss_scale=0, max_answer_length=30, n_best_size=20,\n",
       "       no_cuda=False, null_score_diff_threshold=0.0, num_train_epochs=2,\n",
       "       output_dir='models', predict_batch_size=8, seed=42, server_ip='',\n",
       "       server_port='', train_batch_size=12, verbose_logging=False,\n",
       "       version_2_with_negative=False, warmup_proportion=0.1, warmup_steps=0)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reader configured with the fine-tuning hyper-parameters used for this run.\n",
    "reader = BertQA(\n",
    "    train_batch_size=12,\n",
    "    learning_rate=3e-5,\n",
    "    num_train_epochs=2,\n",
    "    do_lower_case=True,\n",
    "    output_dir='models',\n",
    ")\n",
    "\n",
    "# Fine-tune on the preprocessed (examples, features) pair; fit's return value is this cell's output.\n",
    "reader.fit(X=(train_examples, train_features))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Send model to CPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use one device object for both the model transfer and the reader's device attribute.\n",
    "cpu_device = torch.device('cpu')\n",
    "reader.model.to(cpu_device)\n",
    "reader.device = cpu_device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save model locally"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# joblib.dump does not create missing directories, so make sure the target folder exists first.\n",
    "# An empty output_dir resolves to the current directory, which needs no creation.\n",
    "os.makedirs(reader.output_dir or '.', exist_ok=True)\n",
    "joblib.dump(reader, os.path.join(reader.output_dir, 'bert_qa.joblib'))"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
